import sys
import os
import argparse
import yaml
import json
import time
import asyncio
from pathlib import Path
import torch
import torch.nn.functional as F
import copy
from typing import List,Union,Literal
import random
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.stdout.reconfigure(encoding='utf-8')

from AgentDropout.utils.const import AgentPrune_ROOT
from AgentDropout.graph.graph import Graph
from AgentDropout.tools.reader.readers import JSONReader, JSONLReader
from AgentDropout.utils.globals import Time
from AgentDropout.utils.globals import Cost, PromptTokens, CompletionTokens
from AgentDropout.utils.utils import nuclear_norm,frobenius_norm
from datasets.gsm8k_dataset import svamp_data_process,gsm_get_predict, gsm_data_process,multiarith_data_process
from datasets.aqua_dataset import aqua_data_process,aqua_get_predict
from AgentDropout.utils.globals import PromptTokens, CompletionTokens
from AgentDropout.agents.agent_registry import AgentRegistry

def load_result(result_file):
    """Return the JSON result list stored at *result_file*.

    If the file does not exist yet, it is created containing an empty
    JSON array, so callers always get back a list they can extend.
    """
    if not result_file.exists():
        with open(result_file, 'w', encoding='utf-8') as handle:
            json.dump([], handle)
    with open(result_file, 'r', encoding='utf-8') as handle:
        return json.load(handle)

def dataloader(data_list, batch_size, i_batch):
    """Return the *i_batch*-th batch of *data_list*, or ``None`` when exhausted.

    Callers in this file guard with ``if current_batch is None: break``; the
    original implementation returned an empty list for an out-of-range batch,
    so that guard never fired.  Returning ``None`` for an empty slice makes
    the exhaustion check actually work while leaving in-range batches
    unchanged.
    """
    start = i_batch * batch_size
    batch = data_list[start:start + batch_size]
    return batch if batch else None

def load_config(config_path):
    """Parse the YAML file at *config_path* and return the loaded object."""
    with open(config_path, 'r', encoding='utf-8') as cfg:
        config = yaml.safe_load(cfg)
    return config
    
def parse_args():
    """Parse command-line arguments and prepare the shared result directory.

    Returns:
        argparse.Namespace: parsed arguments.  ``agent_names`` and
        ``agent_nums`` are validated to have the same length (one count per
        agent name); a mismatch aborts via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Experiments on gsm8k")
    parser.add_argument("--dataset_json", type=str, default="datasets/SVAMP/test.json")
    parser.add_argument("--result_file", type=str, default=None)
    parser.add_argument("--llm_name", type=str, default="gpt-3.5-turbo")
    parser.add_argument('--mode', type=str, default='FullConnected',
                        choices=['DirectAnswer', 'FullConnected', 'Random', 'Chain','Debate','Layered','Star'],
                        help="Mode of operation. Default is 'FullConnected'.")
    parser.add_argument('--lr', type=float, default=0.1, help="learning rate")
    parser.add_argument('--delta', type=float, default=0.1, help="noise level")
    parser.add_argument('--batch_size', type=int, default=4, help="batch size")
    parser.add_argument('--imp_per_iterations', type=int, default=5, help="Prune every few iterations. Default 5.")
    parser.add_argument('--num_rounds', type=int, default=1, help="Number of optimization/inference rounds for one query")
    # Help text used to claim "Default 0.05", contradicting the actual
    # default of 0.25 below.
    parser.add_argument('--pruning_rate', type=float, default=0.25, help="The Rate of Pruning. Default 0.25.")
    parser.add_argument('--num_iterations', type=int, default=10, help="The num of training iterations.")
    parser.add_argument('--domain', type=str, default="gsm8k", help="Domain (the same as dataset name), default 'gsm8k'")
    parser.add_argument('--agent_names', nargs='+', type=str, default=['MathSolver'],
                        help='Specify agent names as a list of strings')
    parser.add_argument('--agent_nums', nargs='+', type=int, default=[4],
                        help='Specify the number of agents for each name in agent_names')
    parser.add_argument('--decision_method', type=str, default='FinalRefer',
                        help='The decision method of the agentprune')
    parser.add_argument('--optimized_spatial', action='store_true')
    parser.add_argument('--optimized_temporal', action='store_true')
    parser.add_argument('--diff', action='store_true')
    parser.add_argument('--dec', action='store_true')
    parser.add_argument('--cot', action='store_true')
    args = parser.parse_args()
    # Ensure the shared result directory exists before any run writes to it.
    result_path = AgentPrune_ROOT / "result"
    os.makedirs(result_path, exist_ok=True)
    if len(args.agent_names) != len(args.agent_nums):
        parser.error("The number of agent names must match the number of agent counts.")

    return args

async def main():
    args = parse_args()
    result_file = None
    dataset = JSONReader.parse_file(args.dataset_json)
    dataset = svamp_data_process(dataset)
    train_dataset = JSONReader.parse_file('datasets/SVAMP/train.json')
    train_dataset = svamp_data_process(train_dataset)

    current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    Time.instance().value = current_time
    result_dir = Path(f"{AgentPrune_ROOT}/result/SVAMP")
    result_dir.mkdir(parents=True, exist_ok=True)
    result_file = result_dir / f"{args.domain}_llama3_{current_time}.json"
    
    agent_names = [name for name,num in zip(args.agent_names,args.agent_nums) for _ in range(num)]
    decision_method = args.decision_method
    kwargs = get_kwargs(args.mode,len(agent_names))

    graph = Graph(domain="gsm8k",
                    llm_name=args.llm_name,
                    agent_names=agent_names,
                    decision_method=decision_method,
                    optimized_spatial=args.optimized_spatial,
                    optimized_temporal=args.optimized_temporal,
                    rounds=args.num_rounds,
                    diff=args.diff,
                    dec=args.dec,
                    **kwargs)
    
    if args.dec:
        graph.optimized_spatial=False
        graph.optimized_temporal=False
        total_solved, total_executed = (0, 0)
        if not graph.diff:
            optimizer = torch.optim.Adam([graph.spatial_logits_1,graph.temporal_logits_1], lr=args.lr)
        else:
            optimizer = torch.optim.Adam(list(graph.spatial_logits_1.parameters()) + list(graph.temporal_logits_1.parameters()),lr=args.lr)
        for i_batch in range(args.num_iterations):
            print(f"Train batch {i_batch}",80*'-')
            start_ts = time.time()
            answer_log_probs = []
            answers = []
            add_losses = []
            
            current_batch = dataloader(train_dataset,20,i_batch)
            if current_batch is None:
                print("No more data available.")
                break
            
            for i_record, record in enumerate(current_batch):
                realized_graph = copy.deepcopy(graph)
                realized_graph.spatial_logits_1 = graph.spatial_logits_1
                realized_graph.temporal_logits_1 = graph.temporal_logits_1
                
                if not graph.diff:
                    spatial_matrix_train = realized_graph.spatial_logits_1.reshape((sum(args.agent_nums),sum(args.agent_nums)))
                    temporal_matrix_train = realized_graph.temporal_logits_1.reshape((sum(args.agent_nums),sum(args.agent_nums)))
                else:
                    spatial_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.spatial_logits_1]
                    temporal_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.temporal_logits_1]
                spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
                temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
                if not graph.diff:
                    loss_s = nuclear_norm(spatial_matrix_train)
                    loss_t = nuclear_norm(temporal_matrix_train)
                    frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
                    frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)
                else:
                    # loss_s = sum(nuclear_norm(matrix) for matrix in spatial_matrix_train)
                    # loss_t = sum(nuclear_norm(matrix) for matrix in temporal_matrix_train)
                    # frob_loss_s = sum(frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train)
                    # frob_loss_t = sum(frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train)
                    loss_s = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in spatial_matrix_train]))
                    loss_t = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in temporal_matrix_train]))
                    frob_loss_s = torch.mean(torch.stack([frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train]))
                    frob_loss_t = torch.mean(torch.stack([frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train]))
                add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta)
                add_loss=0
                task = record["task"]
                step = record["step"]
                answer = record["answer"]
                answers.append(answer)
                input_dict = {"task": task}
                answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,args.num_rounds,skip=True)))
                add_losses.append(add_loss)
                
            raw_results = await asyncio.gather(*answer_log_probs)
            raw_answers, log_probs = zip(*raw_results)
            loss_list: List[torch.Tensor] = []
            utilities: List[float] = []
            data = load_result(result_file)
            
            for task, answer, log_prob, add_loss, true_answer in zip(current_batch, raw_answers, log_probs, add_losses, answers):
                predict_answer = gsm_get_predict(answer[0])
                is_solved = float(predict_answer)==float(true_answer)
                total_solved = total_solved + is_solved
                total_executed = total_executed + 1
                accuracy = total_solved/ total_executed
                utility = is_solved
                utilities.append(utility)
                single_loss = -log_prob * utility
                loss_list.append(single_loss+add_loss)
                updated_item = {
                    "Question": task,
                    "Answer": true_answer,
                    "Step": step,
                    "Response": answer,
                    "Attempt answer": predict_answer,
                    "Solved": is_solved,
                    "Total solved": total_solved,
                    "Total executed": total_executed,
                    "Accuracy": accuracy
                }
                # data.append(updated_item)
                print(f"##########Final Log:{json.dumps(updated_item)}")
            with open(result_file, 'w',encoding='utf-8') as file:
                json.dump(data, file, indent=4)
            
            total_loss = torch.mean(torch.stack(loss_list))
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            if not graph.diff:
                spatial_probs = torch.sigmoid(graph.spatial_logits_1)
                temporal_probs = torch.sigmoid(graph.temporal_logits_1)
            else:
                spatial_probs = [torch.sigmoid(logit) for logit in graph.spatial_logits_1]
                temporal_probs = [torch.sigmoid(logit) for logit in graph.temporal_logits_1]
            
            print(f"Batch time {time.time() - start_ts:.3f}")
            print(f"Accuracy: {accuracy}")
            print("utilities:", utilities)
            print("loss:", total_loss.item())
            # print("Spatial logits Grad:", graph.spatial_logits.grad)
            # print("Temporal logits Grad:", graph.spatial_logits.grad)
            print("Spatial logits:", graph.spatial_logits_1)
            print("Temporal logits:", graph.temporal_logits_1)
            print("Spatial probs:", spatial_probs)
            print("Temporal probs:", temporal_probs)
            print("Spatial masks:", graph.spatial_masks)
            print("Temporal logits:", graph.temporal_masks)
            
            if (i_batch+1)%args.imp_per_iterations == 0 and i_batch < args.num_iterations and (args.optimized_spatial or args.optimized_temporal):
                if not graph.diff:
                    print("spatial sparsity:",graph.spatial_masks.sum()/graph.spatial_masks.numel())
                    print("temporal sparsity:",graph.temporal_masks.sum()/graph.temporal_masks.numel())
                else:
                    print("spatial sparsity:",graph.spatial_masks[0].sum()/graph.spatial_masks[0].numel())
                    print("temporal sparsity:",graph.temporal_masks[0].sum()/graph.temporal_masks[0].numel())
            print(f"Cost {Cost.instance().value}")
            print(f"PromptTokens {PromptTokens.instance().value}")
            print(f"CompletionTokens {CompletionTokens.instance().value}")
        graph.update_masks_dec()

    if not graph.diff:
        optimizer = torch.optim.Adam([graph.spatial_logits,graph.temporal_logits], lr=args.lr)    
    else:
        optimizer = torch.optim.Adam(list(graph.spatial_logits.parameters()) + list(graph.temporal_logits.parameters()),lr=args.lr)  
    
    num_batches = int(len(dataset)/args.batch_size)
    total_solved, total_executed = (0, 0)
    
    
    if args.optimized_temporal or args.optimized_spatial:
        # graph.optimized_spatial=True
        # graph.optimized_temporal=True
        for i_batch in range(0):
            print(f"Train batch {i_batch}",80*'-')
            start_ts = time.time()
            answer_log_probs = []
            answers = []
            add_losses = []
            
            current_batch = dataloader(train_dataset,10,i_batch)
            if current_batch is None:
                print("No more data available.")
                break
            
            for i_record, record in enumerate(current_batch):
                realized_graph = copy.deepcopy(graph)
                realized_graph.spatial_logits = graph.spatial_logits
                realized_graph.temporal_logits = graph.temporal_logits
                
                if not graph.diff:
                    spatial_matrix_train = realized_graph.spatial_logits.reshape((sum(args.agent_nums),sum(args.agent_nums)))
                    temporal_matrix_train = realized_graph.temporal_logits.reshape((sum(args.agent_nums),sum(args.agent_nums)))
                else:
                    spatial_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.spatial_logits]
                    temporal_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.temporal_logits]
                spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
                temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
                if not graph.diff:
                    loss_s = nuclear_norm(spatial_matrix_train)
                    loss_t = nuclear_norm(temporal_matrix_train)
                    frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
                    frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)
                else:
                    # loss_s = sum(nuclear_norm(matrix) for matrix in spatial_matrix_train)
                    # loss_t = sum(nuclear_norm(matrix) for matrix in temporal_matrix_train)
                    # frob_loss_s = sum(frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train)
                    # frob_loss_t = sum(frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train)
                    loss_s = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in spatial_matrix_train]))
                    loss_t = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in temporal_matrix_train]))
                    frob_loss_s = torch.mean(torch.stack([frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train]))
                    frob_loss_t = torch.mean(torch.stack([frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train]))
                add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta)
                
                task = record["task"]
                step = record["step"]
                answer = record["answer"]
                answers.append(answer)
                input_dict = {"task": task}
                answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,args.num_rounds)))
                add_losses.append(add_loss)
                
            raw_results = await asyncio.gather(*answer_log_probs)
            raw_answers, log_probs = zip(*raw_results)
            loss_list: List[torch.Tensor] = []
            utilities: List[float] = []
            data = load_result(result_file)
            
            for task, answer, log_prob, add_loss, true_answer in zip(current_batch, raw_answers, log_probs, add_losses, answers):
                predict_answer = gsm_get_predict(answer[0])
                is_solved = float(predict_answer)==float(true_answer)
                total_solved = total_solved + is_solved
                total_executed = total_executed + 1
                accuracy = total_solved/ total_executed
                utility = is_solved
                utilities.append(utility)
                single_loss = -log_prob * utility
                loss_list.append(single_loss+add_loss)
                updated_item = {
                    "Question": task,
                    "Answer": true_answer,
                    "Step": step,
                    "Response": answer,
                    "Attempt answer": predict_answer,
                    "Solved": is_solved,
                    "Total solved": total_solved,
                    "Total executed": total_executed,
                    "Accuracy": accuracy
                }
                # data.append(updated_item)
                print(f"##########Final Log:{json.dumps(updated_item)}")
            with open(result_file, 'w',encoding='utf-8') as file:
                json.dump(data, file, indent=4)
            
            total_loss = torch.mean(torch.stack(loss_list))
            if args.optimized_spatial or args.optimized_temporal:
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
            if not graph.diff:
                spatial_probs = torch.sigmoid(graph.spatial_logits)
                temporal_probs = torch.sigmoid(graph.temporal_logits)
            else:
                spatial_probs = [torch.sigmoid(logit) for logit in graph.spatial_logits]
                temporal_probs = [torch.sigmoid(logit) for logit in graph.temporal_logits]
            
            print(f"Batch time {time.time() - start_ts:.3f}")
            print(f"Accuracy: {accuracy}")
            print("utilities:", utilities)
            print("loss:", total_loss.item())
            # print("Spatial logits Grad:", graph.spatial_logits.grad)
            # print("Temporal logits Grad:", graph.spatial_logits.grad)
            print("Spatial logits:", graph.spatial_logits)
            print("Temporal logits:", graph.temporal_logits)
            print("Spatial probs:", spatial_probs)
            print("Temporal probs:", temporal_probs)
            print("Spatial masks:", graph.spatial_masks)
            print("Temporal logits:", graph.temporal_masks)
            
            if (i_batch+1)%2 == 0 and i_batch < 4 and (args.optimized_spatial or args.optimized_temporal):
                if not graph.diff:
                    spatial_masks, temporal_masks = graph.update_masks(args.pruning_rate)
                else:
                    spatial_masks, temporal_masks = graph.update_masks_diff(args.pruning_rate)
                print("spatial masks:",spatial_masks)
                print("temporal masks:",temporal_masks)
                if not graph.diff:
                    print("spatial sparsity:",spatial_masks.sum()/spatial_masks.numel())
                    print("temporal sparsity:",temporal_masks.sum()/temporal_masks.numel())
                else:
                    print("spatial sparsity:",spatial_masks[0].sum()/spatial_masks[0].numel())
                    print("temporal sparsity:",temporal_masks[0].sum()/temporal_masks[0].numel())
            if i_batch+1 == args.num_iterations:
                args.optimized_spatial = False
                args.optimized_temporal = False
            print(f"Cost {Cost.instance().value}")
            print(f"PromptTokens {PromptTokens.instance().value}")
            print(f"CompletionTokens {CompletionTokens.instance().value}")

    PromptTokens.instance().reset()
    CompletionTokens.instance().reset()
    total_solved, total_executed = (0, 0)

    # graph.clear_spatial_connection()
    # graph.clear_temporal_connection()
    # graph.domain='aqua'
    # graph.nodes = {}
    # print(graph.nodes)
    # graph.agent_names = ['MathSolver_aqua']
    # graph.agent_names = [name for name,num in zip(['MathSolver_aqua'],args.agent_nums) for _ in range(num)]
    # graph.node_kwargs = [{} for _ in graph.agent_names]
    # graph.init_nodes()
    # print(graph.nodes)
    # graph.potential_spatial_edges = []
    # graph.potential_temporal_edges = []
    # graph.init_potential_edges()
    # graph.decision_node = AgentRegistry.get(args.decision_method, **{"domain":'aqua',"llm_name":args.llm_name})

    for i_batch in range(num_batches):
        print(f"Batch {i_batch}",80*'-')
        start_ts = time.time()
        answer_log_probs = []
        answers = []
        add_losses = []
        
        current_batch = dataloader(dataset,args.batch_size,i_batch)
        if current_batch is None:
            print("No more data available.")
            break
        
        print(11111111)
        for i_record, record in enumerate(current_batch):
            realized_graph = copy.deepcopy(graph)
            realized_graph.spatial_logits = graph.spatial_logits
            realized_graph.temporal_logits = graph.temporal_logits
            
            if not graph.diff:
                spatial_matrix_train = realized_graph.spatial_logits.reshape((sum(args.agent_nums),sum(args.agent_nums)))
                temporal_matrix_train = realized_graph.temporal_logits.reshape((sum(args.agent_nums),sum(args.agent_nums)))
            else:
                spatial_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.spatial_logits]
                temporal_matrix_train = [param.reshape((sum(args.agent_nums), sum(args.agent_nums))) for param in realized_graph.temporal_logits]
            spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
            temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"],dtype=torch.float32).reshape((len(agent_names),len(agent_names)))
            if not graph.diff:
                loss_s = nuclear_norm(spatial_matrix_train)
                loss_t = nuclear_norm(temporal_matrix_train)
                frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
                frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)
            else:
                # loss_s = sum(nuclear_norm(matrix) for matrix in spatial_matrix_train)
                # loss_t = sum(nuclear_norm(matrix) for matrix in temporal_matrix_train)
                # frob_loss_s = sum(frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train)
                # frob_loss_t = sum(frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train)
                loss_s = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in spatial_matrix_train]))
                loss_t = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in temporal_matrix_train]))
                frob_loss_s = torch.mean(torch.stack([frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train]))
                frob_loss_t = torch.mean(torch.stack([frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train]))
            add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta)
            
            task = record["task"]
            step = record["step"]
            answer = record["answer"]
            answers.append(answer)
            input_dict = {"task": task}

            answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,args.num_rounds,case=True)))

            add_losses.append(add_loss)
        
        print(22222222)
        raw_results = await asyncio.gather(*answer_log_probs)
        print(33333333)
        raw_answers, log_probs, all_answers = zip(*raw_results)
        loss_list: List[torch.Tensor] = []
        utilities: List[float] = []
        data = load_result(result_file)
        
        for task, answer, log_prob, add_loss, true_answer, all_answer in zip(current_batch, raw_answers, log_probs, add_losses, answers, all_answers):
            predict_answer = gsm_get_predict(answer[0])
            is_solved = float(predict_answer)==float(true_answer)
            # predict_answer = aqua_get_predict(answer[0])
            # is_solved = predict_answer==true_answer
            total_solved = total_solved + is_solved
            total_executed = total_executed + 1
            accuracy = total_solved/ total_executed
            utility = is_solved
            utilities.append(utility)
            single_loss = -log_prob * utility
            loss_list.append(single_loss+add_loss)
            updated_item = {
                "Question": task,
                "Answer": true_answer,
                "Step": step,
                "All_answers": all_answer,
                "Response": answer,
                "Attempt answer": predict_answer,
                "Solved": is_solved,
                "Total solved": total_solved,
                "Total executed": total_executed,
                "Accuracy": accuracy
            }
            data.append(updated_item)
            print(f"##########Final Log:{json.dumps(updated_item)}")
        with open(result_file, 'w',encoding='utf-8') as file:
            json.dump(data, file, indent=4)
        
        # total_loss = torch.mean(torch.stack(loss_list))
        # if args.optimized_spatial or args.optimized_temporal:
        #     optimizer.zero_grad()
        #     total_loss.backward()
        #     optimizer.step()
        # spatial_probs = torch.sigmoid(graph.spatial_logits)
        # temporal_probs = torch.sigmoid(graph.temporal_logits)
        
        print(f"Batch time {time.time() - start_ts:.3f}")
        print(f"Accuracy: {accuracy}")
        print("utilities:", utilities)
        # print("loss:", total_loss.item())
        # print("Spatial logits Grad:", graph.spatial_logits.grad)
        # print("Temporal logits Grad:", graph.spatial_logits.grad)
        # print("Spatial logits:", graph.spatial_logits)
        # print("Temporal logits:", graph.temporal_logits)
        # print("Spatial probs:", spatial_probs)
        # print("Temporal probs:", temporal_probs)
        # print("Spatial masks:", graph.spatial_masks)
        # print("Temporal logits:", graph.temporal_masks)
        
        print(f"Cost {Cost.instance().value}")
        print(f"PromptTokens {PromptTokens.instance().value}")
        print(f"CompletionTokens {CompletionTokens.instance().value}")


def get_kwargs(mode:Union[Literal['DirectAnswer'],Literal['FullConnected'],Literal['Random'],Literal['Chain'],Literal['Debate'],Literal['Layered'],Literal['Star']]
               ,N:int):
    """Build the Graph constructor kwargs for a given topology *mode*.

    Returns a dict with the fixed spatial/temporal adjacency masks
    (N x N nested lists of 0/1), the initial edge probabilities, and
    optional per-node kwargs.  Masks stay None for unrecognized modes.
    """
    initial_spatial_probability: float = 0.5
    initial_temporal_probability: float = 0.5
    fixed_spatial_masks: List[List[int]] = None
    fixed_temporal_masks: List[List[int]] = None
    node_kwargs = None

    def _layered_adjacency(n, layer_num=2):
        # Split nodes into consecutive layers (earlier layers absorb the
        # remainder), then connect each node to every node one layer up.
        base, extra = divmod(n, layer_num)
        layer_of = []
        for idx in range(layer_num):
            layer_of += [idx] * (base + (1 if idx < extra else 0))
        # random.shuffle(layer_of)
        adj = [[0] * n for _ in range(n)]
        for src in range(n):
            for dst in range(n):
                if layer_of[dst] == layer_of[src] + 1:
                    adj[src][dst] = 1
        return adj

    def _star_adjacency(n):
        # Strict upper-triangular ones: each node points to all later nodes.
        return [[1 if col > row else 0 for col in range(n)] for row in range(n)]

    if mode == 'DirectAnswer':
        # Single isolated solver node.
        fixed_spatial_masks = [[0]]
        fixed_temporal_masks = [[0]]
        node_kwargs = [{'role':'Math Solver'}]
    elif mode == 'FullConnected':
        # Everyone talks to everyone (no self-loops spatially).
        fixed_spatial_masks = [[0 if row == col else 1 for col in range(N)] for row in range(N)]
        fixed_temporal_masks = [[1] * N for _ in range(N)]
    elif mode == 'Random':
        # Coin-flip each off-diagonal spatial edge and each temporal edge.
        fixed_spatial_masks = [[random.randint(0, 1)  if i!=j else 0 for i in range(N)] for j in range(N)]
        fixed_temporal_masks = [[random.randint(0, 1) for _ in range(N)] for _ in range(N)]
    elif mode == 'Chain':
        # Each node feeds only its immediate successor; temporally the last
        # node loops back to the first.
        fixed_spatial_masks = [[1 if col == row + 1 else 0 for col in range(N)] for row in range(N)]
        fixed_temporal_masks = [[1 if row == N - 1 and col == 0 else 0 for col in range(N)] for row in range(N)]
    elif mode == 'Debate':
        # No intra-round edges; full inter-round communication.
        fixed_spatial_masks = [[0] * N for _ in range(N)]
        fixed_temporal_masks = [[1] * N for _ in range(N)]
    elif mode == 'Layered':
        fixed_spatial_masks = _layered_adjacency(N)
        fixed_temporal_masks = [[1] * N for _ in range(N)]
    elif mode == 'Star':
        fixed_spatial_masks = _star_adjacency(N)
        fixed_temporal_masks = [[1] * N for _ in range(N)]

    return {"initial_spatial_probability": initial_spatial_probability,
            "fixed_spatial_masks": fixed_spatial_masks,
            "initial_temporal_probability": initial_temporal_probability,
            "fixed_temporal_masks": fixed_temporal_masks,
            "node_kwargs": node_kwargs}

if __name__ == '__main__':
    # Script entry point: run the full async train/evaluate pipeline.
    asyncio.run(main())
